from collections.abc import Sequence
from dataclasses import dataclass

from xdsl.builder import Builder
from xdsl.context import Context
from xdsl.dialect_interfaces.constant_materialization import (
    ConstantMaterializationInterface,
)
from xdsl.ir import Attribute, Operation, SSAValue, TypeAttribute
from xdsl.pattern_rewriter import PatternRewriter
from xdsl.traits import HasFolder


@dataclass
class Folder:
    """
    Helper that folds operations through their `HasFolder` trait.

    When a folder yields constant attributes instead of SSA values, the attributes
    are turned into constant operations using the `ConstantMaterializationInterface`
    registered on the folded operation's dialect.
    """

    # Context used to look up the dialect of the operation being folded, in order
    # to obtain its constant materialization interface.
    context: Context

    def try_fold(self, op: Operation) -> tuple[list[SSAValue], list[Operation]] | None:
        """
        Try to fold the given operation.

        Returns a tuple with the list of SSAValues that replace the results of the
        operation, and a list of constant operations that were created from the
        constant attributes generated by the folder. The constant operations are
        not inserted anywhere by this method; that is left to the caller.

        Returns None if any of the following holds:
          - the operation has no `HasFolder` trait,
          - its folder returns None,
          - an attribute needs materializing but the operation's dialect has no
            `ConstantMaterializationInterface`,
          - the interface fails to materialize one of the attributes.

        Note that while this folds only one operation, multiple new operations can
        be created: each result of the original operation might be replaced by a
        new constant operation.
        """
        trait = op.get_trait(HasFolder)
        if trait is None:
            return None
        folded = trait.fold(op)
        if folded is None:
            return None

        results: list[SSAValue] = []
        new_ops: list[Operation] = []
        # Looked up lazily (only once an attribute result actually needs
        # materializing) and at most once per call, instead of once per result.
        interface: ConstantMaterializationInterface | None = None
        # strict=True: the folder must produce exactly one value per result.
        for val, original_result in zip(folded, op.results, strict=True):
            if isinstance(val, SSAValue):
                results.append(val)
                continue
            assert isinstance(val, Attribute)
            if interface is None:
                dialect = self.context.get_dialect(op.dialect_name())
                interface = dialect.get_interface(ConstantMaterializationInterface)
                if interface is None:
                    return None
            result_type = original_result.type
            assert isinstance(result_type, TypeAttribute)
            new_op = interface.materialize_constant(val, result_type)
            if new_op is None:
                return None
            new_ops.append(new_op)
            # The first result of the materialized constant replaces this result.
            results.append(new_op.results[0])
        return results, new_ops

    def insert_with_fold(self, op: Operation, builder: Builder) -> Sequence[SSAValue]:
        """
        Insert `op` using the provided builder, trying to fold it first.

        If folding succeeds, `op` itself is NOT inserted. The constant operations
        materialized by the folder are inserted instead, and the values that
        replace `op`'s results are returned.
        If folding fails, `op` is inserted as-is and `op.results` is returned.

        Raises:
            ValueError: if `op` can be folded but is already attached to a block.
        """
        folded = self.try_fold(op)
        if folded is None:
            builder.insert(op)
            return op.results
        # On the folded path `op` is never inserted, so an op that already lives
        # in a block is rejected rather than silently left where it is.
        if op.parent is not None:
            raise ValueError(
                "Can't insert_with_fold fold an operation that already has a parent."
            )
        values, new_ops = folded
        builder.insert_op(new_ops)
        return values

    def replace_with_fold(
        self, op: Operation, rewriter: PatternRewriter, safe_erase: bool = True
    ) -> Sequence[SSAValue] | None:
        """
        Replace `op` with its folded results using `rewriter`.

        On success, this calls `rewriter.replace_op` with the materialized constant
        operations and the replacement values. `safe_erase` is forwarded unchanged,
        and the replacement values are returned.
        If `op` cannot be folded, the rewriter is not called and None is returned.
        """
        folded = self.try_fold(op)
        if folded is None:
            return None
        values, new_ops = folded
        rewriter.replace_op(op, new_ops, values, safe_erase)
        return values
